feat: add trigonometric functions #861
This just makes the compiler happy and is not yet tested!
Not sure if it might be a good idea to pull all the trigonometry functions (including the existing ones) into a new enum for the IR, because that's a lot of added stuff, and it's now enough to warrant its own category I think. We should at some point also think about applying a similar separation to the compilers themselves, but that's a larger rework that wouldn't go into this PR.
I'm currently looking into the MLIR functions for that stuff - the bindings are auto-generated I believe, so if the functions don't exist, they might be in a different namespace or need a different way to handle them.
As for the float types, casting to float for unsupported float types (i.e. F16, BF16) is reasonable. There do appear to be double versions of the functions, so that will work natively.
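To illustrate the "cast to float for unsupported float types" idea, here is a minimal plain-Rust sketch (not CubeCL code). It assumes bf16 values are carried as raw `u16` bit patterns; the helper names `bf16_bits_to_f32`, `f32_to_bf16_bits` and `unary_via_f32` are hypothetical and exist only for this example.

```rust
/// Widen bf16 bits to f32: bf16 is the upper 16 bits of an IEEE-754 f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow an f32 to bf16 bits using round-to-nearest-even.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Force a mantissa bit on so the result stays a NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit makes ties round to even.
    let rounding = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding) >> 16) as u16
}

/// Apply an f32 math function to a bf16 value by promoting, computing in
/// f32, and demoting the result again.
fn unary_via_f32(bits: u16, op: fn(f32) -> f32) -> u16 {
    f32_to_bf16_bits(op(bf16_bits_to_f32(bits)))
}
```

A call such as `unary_via_f32(x, f32::sinh)` then evaluates `sinh` in f32 and only rounds once, when narrowing the result.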
precision for some of the new trigonometric operations
Must I implement this in this PR or will this be part of another refactor?
Can you please double-check my implementation? I don't really understand at which point the right functions are written to the shader for each dialect...
I think that would be better to do after all the outstanding work has been merged, so I'll do it separately.
#[cube]
pub fn to_degrees<F: Float>(val: F) -> F {
    val * F::new(180.0 / core::f32::consts::PI)
}
It would keep it in one place rather than implemented separately for each compiler.
CUDA is already correct, not sure what the WGSL behaviour even is because it's an experimental extension. It might just support f16 overloads on those functions by default, but I can't yet test it because of some issues with features on Vulkan.
There are intrinsics for WGSL and SPIR-V (at least rspirv-ext), and Rust's f32 and f64 also have to_degrees and to_radians operations. Also, as a user, I would probably search for these functions in the Float namespace instead of in cubecl-std. That's why I think having them there would still make sense.
I removed the dialect-specific compile_instruction_xxx_scalar calls and instead hardcoded the operations directly for cubecl-cpp, since CUDA, HIP and Metal should all implicitly convert to the right type and share the same syntax.
Before things get stale, I'm marking this as ready for review. Currently, 21 tests are failing, 20 of them related to the missing CPU implementation of the inverse trigonometric functions (arc-*) and one is that For both problems I can't do more, since the first one needs to be fixed by @marcantoinem I guess? and the second one needs to be fixed by a design decision whether If I should do more, let me know.
that epsilon looks too tight for f16, might have another bug in the precision-aware comparison like when I forgot an |
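As a rough illustration of what a precision-aware comparison can look like, the sketch below scales the tolerance by the machine epsilon of the format under test. It is plain Rust and not the test helper used in this repository; `FloatKind`, `epsilon` and `approx_eq` are hypothetical names, and the `ulps` multiplier is an assumption to be tuned per operation.

```rust
/// Float formats discussed in this thread (hypothetical enum for the sketch).
#[derive(Clone, Copy, Debug)]
enum FloatKind {
    F16,
    BF16,
    F32,
    F64,
}

/// Machine epsilon of each format, expressed as f64.
fn epsilon(kind: FloatKind) -> f64 {
    match kind {
        FloatKind::F16 => 2f64.powi(-10), // 10 explicit mantissa bits
        FloatKind::BF16 => 2f64.powi(-7), // 7 explicit mantissa bits
        FloatKind::F32 => f32::EPSILON as f64,
        FloatKind::F64 => f64::EPSILON,
    }
}

/// Compare with a tolerance of `ulps` epsilons, relative for large values
/// and absolute near zero.
fn approx_eq(actual: f64, expected: f64, kind: FloatKind, ulps: f64) -> bool {
    if actual.is_nan() || expected.is_nan() {
        return actual.is_nan() && expected.is_nan();
    }
    if actual.is_infinite() || expected.is_infinite() {
        // Infinities only match an infinity of the same sign.
        return actual == expected;
    }
    let tolerance = ulps * epsilon(kind) * expected.abs().max(1.0);
    (actual - expected).abs() <= tolerance
}
```

With this shape, a result that is off by 5e-4 around 1.0 passes for f16 but fails for f32, which is the behaviour one would want from a per-type epsilon.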
The code looks fine to me, but the tests don't pass on the CI. @relativityhd We may also remove some functions in the trig modules; I'm unsure they are all necessary.
@nathanielsimard yes, the CI is failing because, as I mentioned, this PR is blocked by the MLIR melior project: the math module there won't register correctly, as @marcantoinem mentioned.
Sure, just tell me which ones I should remove.
The ones that are not strictly necessary. Using if statements is also very bad on most hardware, so we can remove the functions that have if statements.
Sorry for the long wait - my vacation is over and my time is limited again... I've removed unnecessary trig functions from the std, only keeping hypot, to_radians and to_degrees since they are also present in Rust's f32. This PR is still blocked by the MLIR melior project. I still get the following error when testing (I've removed the dummy code and enabled ods math support again, as @marcantoinem described): @nathanielsimard What repository is responsible for this? Maybe I can check and fix it there. I am quite confused about the different LLVM repositories and packages used.
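For reference, the three helpers that were kept can be written without any if statements. The following is a plain-Rust sketch on f32, not the `#[cube]` code from this PR; the naive `hypot` here is an assumption made for brevity and, unlike `f32::hypot`, can overflow for very large inputs.

```rust
use core::f32::consts::PI;

/// Convert radians to degrees with a single multiplication.
fn to_degrees(radians: f32) -> f32 {
    radians * (180.0 / PI)
}

/// Convert degrees to radians with a single multiplication.
fn to_radians(degrees: f32) -> f32 {
    degrees * (PI / 180.0)
}

/// Branch-free hypotenuse. Note: x * x + y * y may overflow to infinity
/// for large inputs, which a scaled implementation would avoid.
fn hypot(x: f32, y: f32) -> f32 {
    (x * x + y * y).sqrt()
}
```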
So I checked whether I would find any reason why the math dialect is not properly registered, but I just can't find it. Seems to be present everywhere where other dialects are also present, e.g. Arith. |
Those are the files for the operands, but you also need to include the new
conversion I mentioned earlier so the operands can be converted to libm
calls; otherwise it seems the dialect cannot be converted. With the
CUBECL_DEBUG_MLIR env variable pointing to a directory, you can actually
see the MLIR files generated at each conversion step; for example, rsqrt
ends up as LLVM 1/sqrt.
…On Wed, Nov 5, 2025, 22:11 Tobias Hölzer ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs
<#861 (comment)>:
> @@ -28,6 +28,117 @@ impl<'a> Visitor<'a> {
let result = self.append_operation_with_result(operation);
self.insert_variable(out, result);
}
+ Arithmetic::Sinh(_sinh) => {
Yes and as far as I can see there is nothing which would prevent acosh
etc. from loading, since the dialects are directly imported from mlir:
https://github.com/tracel-ai/tracel-llvm/blob/9d51e26fab8542949c94f7c14c2519113648bf27/crates/tracel-mlir-sys/wrapper.h#L19
and
https://github.com/tracel-ai/tracel-llvm/blob/9d51e26fab8542949c94f7c14c2519113648bf27/crates/tracel-mlir-rs/src/pass/conversion.rs#L35
and
https://github.com/tracel-ai/tracel-llvm/blob/9d51e26fab8542949c94f7c14c2519113648bf27/crates/tracel-mlir-rs/src/dialect/ods.rs#L144
Maybe some files are missing in
https://github.com/tracel-ai/tracel-llvm/blob/9d51e26fab8542949c94f7c14c2519113648bf27/crates/tracel-mlir-rs/src/dialect/ods.rs#L146
?
I have added you all to my repo in case I am blocking your development, just in case |
I have merged from main, but still no success. Still got the I've tried running with the debug flag you mentioned |
You have to activate the mlir-dump feature, for example:
CUBECL_DEBUG_MLIR=/tmp/mlir2 cargo test --features mlir-dump -p cubecl-cpu tests::f32_ty::unary:: -- --nocapture
The order in which the dialect conversions are registered is important; the above commit should fix it.
OMG yes it finally works, thank you! So I guess this PR is now ready for merge! @wingertge @nathanielsimard |
cpp-dialects as is is equal in all dialects
Sooo, tests failing on my machine, but the functions which are failing are outside this PR (plane-shuffle function). From my end everything should be fine. 👍🏻 |
function!(FastCos, "__cosf", false);
function!(Sin, "sin");
function!(Cos, "cos");
function!(Tan, "tan");
This fails on CUDA because htan is undefined, I think this should be false for the f16 support
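One way to picture the fix is a code generator that only emits the `h`-prefixed half intrinsic when one exists and otherwise round-trips through float. The sketch below is plain Rust and is not the `function!` macro from this PR; `Elem`, `emit_unary` and the `has_half_intrinsic` flag are hypothetical names, and the emitted strings assume the CUDA conversions `__half2float` and `__float2half`.

```rust
/// Element types relevant to this example (hypothetical, simplified).
#[derive(Clone, Copy, PartialEq)]
enum Elem {
    F16,
    F32,
}

/// Emit a CUDA call for a unary math function.
/// `has_half_intrinsic` plays the role of the boolean flag discussed above:
/// when false, f16 inputs are computed in float and converted back.
fn emit_unary(name: &str, has_half_intrinsic: bool, elem: Elem, input: &str) -> String {
    match (elem, has_half_intrinsic) {
        // A native half intrinsic exists, e.g. hsin or hcos.
        (Elem::F16, true) => format!("h{name}({input})"),
        // No half intrinsic (e.g. tan): go through the float version.
        (Elem::F16, false) => format!("__float2half({name}f(__half2float({input})))"),
        // Plain float call.
        (Elem::F32, _) => format!("{name}({input})"),
    }
}
```

For `tan` with the flag set to false, this yields `__float2half(tanf(__half2float(x)))` rather than the undefined `htan(x)`.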
input: Input,
_out_elem: Elem<D>,
) -> std::fmt::Result {
write!(f, "{input}*57.29577951308232f")
This is broken: you can't multiply a half by a float literal. The constant should be converted to the correct type. Same with radians.
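A possible shape for the fix is to wrap the constant in a cast to the output element type before multiplying. The sketch below is plain Rust with simplified, hypothetical types (`Elem`, `type_name`, `write_degrees`); the real `Elem<D>` in cubecl-cpp is dialect-aware, and the functional-cast syntax in the emitted string is an assumption about the target language.

```rust
use std::fmt::{self, Write};

/// Simplified stand-in for the compiler's element type (hypothetical).
#[derive(Clone, Copy)]
enum Elem {
    F16,
    F32,
    F64,
}

impl Elem {
    /// Type name as it would appear in the generated source (assumed names).
    fn type_name(self) -> &'static str {
        match self {
            Elem::F16 => "half",
            Elem::F32 => "float",
            Elem::F64 => "double",
        }
    }
}

/// Emit `input * T(180/pi)` with the constant cast to the output type,
/// so a half input is never multiplied by a bare float literal.
fn write_degrees(f: &mut impl Write, input: &str, out_elem: Elem) -> fmt::Result {
    let ty = out_elem.type_name();
    write!(f, "{input}*{ty}(57.29577951308232)")
}
```

The radians conversion would follow the same pattern with the reciprocal constant.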
Add trigonometric functions (atan2 etc.)
Adds the following functions for all Floats:
sinh
cosh
asin
asinh
acos
acosh
atan
atanh
atan2
degrees
radians
Open Questions / Missing parts
LLVM MLIR implementation
I have currently set up a placeholder, with the code that I assumed would work commented out.
It seems that support for these functions needs to be added in the tracel-llvm repository, but I have no clue where.
Non f64/f32
Support for the "special" floats is limited in some places; for example, powf, which I used as a reference, needed a conversion to normal floats.
I am unsure where such conversions are necessary and how I can find out whether they are needed.
Metal safe operations
Since I used the existing sin, cos, tanh and powf functions as examples of how to add the other functions, I stumbled across the Metal implementation of tanh:
I couldn't find anything about this "safe" version in the metal documentation, but I am clearly not an expert.
For which functions is a "safe" implementation needed and which are fine without?
Validate your PR with burn.
It is important that you make sure that you don't introduce any bugs in burn.
Instructions